--- title: Embeddings of scRNA Velocity Data keywords: fastai sidebar: home_sidebar nb_path: "04e Embeddings of scRNA Velocity Data.ipynb" ---
from FRED.embed import ManifoldFlowEmbedder
from FRED.trainers import save_embedding_visualization, visualize_points
# Configure a ManifoldFlowEmbedder and its Trainer for this experiment.
# Only the "diffusion map regularization" and "flow neighbor loss" terms get
# a non-zero weight; all other listed loss terms are weighted 0.
import torch
from FRED.trainers import Trainer  # was used below but never imported

# `device` was referenced below but never defined in this file. Use a GPU
# when one is available and fall back to the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

title = "Flow Neighbor Loss + Diffusion Map Reg"
MFE = ManifoldFlowEmbedder(
    embedding_dimension=2,
    # NOTE(review): layer widths of the embedder network; the first entry
    # presumably has to match the dimensionality of the input points -- confirm.
    embedder_shape=[4, 6, 8, 4, 2],
    device=device,
    sigma=0.5,
    flow_strength=0.5,
    smoothness_grid=True,
)
# Per-term loss weights handed to the Trainer; presumably a weight of 0
# disables the corresponding term -- confirm against Trainer.
loss_weights = {
    "reconstruction": 0,
    "diffusion map regularization": 1,
    "flow neighbor loss": 1,
    "smoothness": 0,
    "kld": 0,
}
# Callbacks the Trainer uses to render the embedding during training.
visualization_functions = [
    save_embedding_visualization,
    visualize_points,
]
FREDtrainer = Trainer(
    FE=MFE,
    loss_weights=loss_weights,
    visualization_functions=visualization_functions,
    device=device,
    title=title,
)
from FRED.datasets import double_helix, rnavelo
import scvelo as scv
from FRED.data_processing import dataloader_from_ndarray, ManifoldWithVectorField
from torch.utils.data import DataLoader
# Build the scRNA-velocity dataset, take a quick look at its diffusion
# coordinates, then train the embedder and visualize the results.
import matplotlib.pyplot as plt  # was used below but never imported

# Simulated scVelo dataset (1000 cells, four switching times). rnavelo
# presumably returns (points, per-point velocity vectors, labels) -- confirm.
X, flow, labels = rnavelo(
    scv.datasets.simulation(n_obs=1000, switches=[0.2, 0.3, 0.5, 1])
)
# In the notebook a bare `flow.shape` displayed the value; in a script it
# was a no-op, so print it explicitly.
print(flow.shape)

ds = ManifoldWithVectorField(
    X,
    flow,
    labels,
    sigma=1,
    flow_strength=5,
    dmap_coords_to_use=4,
    nbhd_strategy="flow neighbors",
    n_neighbors=5,
)
# batch_size=None disables DataLoader's automatic batching: each item the
# dataset yields is passed through unchanged (the dataset presumably builds
# its own batches -- confirm).
dataloader = DataLoader(ds, batch_size=None, shuffle=True)

# Sanity-check plot of diffusion coordinates 1 and 2, colored by label.
# NOTE(review): column 0 is skipped, presumably because it is the trivial
# first diffusion coordinate -- confirm.
DC = dataloader.dataset.diff_coords
plt.scatter(DC[:, 1], DC[:, 2], c=labels)
# Render (and release) the figure now. Otherwise a script never displays it,
# and later plotting could draw onto the same axes.
plt.show()

FREDtrainer.fit(dataloader, n_epochs=100)
FREDtrainer.visualize_loss()
FREDtrainer.training_gif(duration=150)